import os
import random
import torch
import matplotlib.pyplot as plt
import torchvision
from torchvision.datasets import CIFAR10

import eecs598


def _extract_tensors(dset, num=None, x_dtype=torch.float32):
  """
  Extract the data and labels from a CIFAR10 dataset object and convert them to
  tensors.

  Pixel values are divided by 255 and the channel axis is moved from last to
  second position (N, H, W, C) -> (N, C, H, W).

  Input:
  - dset: A torchvision.datasets.CIFAR10 object (anything exposing a `data`
    array of shape (N, 32, 32, 3) and a `targets` sequence of length N).
  - num: Optional. If provided, the number of samples to keep (the first `num`
    samples are kept). Must satisfy 1 <= num <= N.
  - x_dtype: Optional. data type of the input image

  Returns:
  - x: `x_dtype` tensor of shape (N, 3, 32, 32)
  - y: int64 tensor of shape (N,)

  Raises:
  - ValueError: if `num` is provided and is not in the range [1, N].
  """
  data, targets = dset.data, dset.targets
  if num is not None:
    total = len(data)
    if num <= 0 or num > total:
      raise ValueError('Invalid value num=%d; must be in the range [1, %d]'
                       % (num, total))
    # Subsample before converting so that only the kept samples are turned
    # into (much larger) floating point tensors.
    data, targets = data[:num], targets[:num]
  # torch.tensor copies its input, so the returned tensors never alias dset.
  x = torch.tensor(data, dtype=x_dtype).permute(0, 3, 1, 2).div_(255)
  y = torch.tensor(targets, dtype=torch.int64)
  return x, y


def cifar10(num_train=None, num_test=None, x_dtype=torch.float32):
  """
  Return the CIFAR10 dataset, automatically downloading it if necessary.
  This function can also subsample the dataset.

  The dataset is looked up in (and downloaded to) the current working
  directory: a download is requested only when the directory
  'cifar-10-batches-py' does not exist there.

  Inputs:
  - num_train: [Optional] How many samples to keep from the training set.
    If not provided, then keep the entire training set.
  - num_test: [Optional] How many samples to keep from the test set.
    If not provided, then keep the entire test set.
  - x_dtype: [Optional] Data type of the input image

  Returns:
  - x_train: `x_dtype` tensor of shape (num_train, 3, 32, 32)
  - y_train: int64 tensor of shape (num_train,)
  - x_test: `x_dtype` tensor of shape (num_test, 3, 32, 32)
  - y_test: int64 tensor of shape (num_test,)
  """
  download = not os.path.isdir('cifar-10-batches-py')
  # The download flag is only passed for the train split; the test split is
  # constructed afterwards without it.
  # NOTE(review): this relies on the train download also providing the test
  # batch -- confirm against the torchvision CIFAR10 documentation.
  dset_train = CIFAR10(root='.', download=download, train=True)
  dset_test = CIFAR10(root='.', train=False)
  x_train, y_train = _extract_tensors(dset_train, num_train, x_dtype)
  x_test, y_test = _extract_tensors(dset_test, num_test, x_dtype)

  return x_train, y_train, x_test, y_test


def preprocess_cifar10(
    cuda=True,
    show_examples=True,
    bias_trick=False,
    validation_ratio=0.2,
    dtype=torch.float32):
  """
  Returns a preprocessed version of the CIFAR10 dataset, automatically
  downloading if necessary. We perform the following steps:

  (0) [Optional] Visualize some images from the dataset
  (1) Normalize the data by subtracting the per-channel (RGB) mean of the
      training set from both the training and the test images
  (2) Reshape each image of shape (3, 32, 32) into a vector of shape (3072,)
  (3) [Optional] Bias trick: add an extra dimension of ones to the data
  (4) Carve out a validation set from the training set

  Inputs:
  - cuda: If true, move the entire dataset to the GPU
  - validation_ratio: Float in the range (0, 1) giving the fraction of the train
    set to reserve for validation
  - bias_trick: Boolean telling whether or not to apply the bias trick
  - show_examples: Boolean telling whether or not to visualize data samples
  - dtype: Optional, data type of the input image X

  Returns a dictionary with the following keys:
  - 'X_train': `dtype` tensor of shape (N_train, D) giving training images
  - 'X_val': `dtype` tensor of shape (N_val, D) giving val images
  - 'X_test': `dtype` tensor of shape (N_test, D) giving test images
  - 'y_train': int64 tensor of shape (N_train,) giving training labels
  - 'y_val': int64 tensor of shape (N_val,) giving val labels
  - 'y_test': int64 tensor of shape (N_test,) giving test labels

  N_train, N_val, and N_test are the number of examples in the train, val, and
  test sets respectively. The precise values of N_train and N_val are determined
  by the input parameter validation_ratio. D is the dimension of the image data;
  if bias_trick is False, then D = 32 * 32 * 3 = 3072;
  if bias_trick is True then D = 1 + 32 * 32 * 3 = 3073.
  """
  X_train, y_train, X_test, y_test = cifar10(x_dtype=dtype)

  # Move data to the GPU
  if cuda:
    X_train = X_train.cuda()
    y_train = y_train.cuda()
    X_test = X_test.cuda()
    y_test = y_test.cuda()

  # 0. Visualize some examples from the dataset.
  if show_examples:
    classes = [
        'plane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'
    ]
    samples_per_class = 12
    samples = []
    # Fix the seed so that the same images are displayed on every call.
    eecs598.reset_seed(0)
    for y, cls in enumerate(classes):
      # Row label, placed to the left of the y-th row of the image grid.
      plt.text(-4, 34 * y + 18, cls, ha='right')
      idxs, = (y_train == y).nonzero(as_tuple=True)
      for _ in range(samples_per_class):
        idx = idxs[random.randrange(idxs.shape[0])].item()
        samples.append(X_train[idx])
    img = torchvision.utils.make_grid(samples, nrow=samples_per_class)
    plt.imshow(eecs598.tensor_to_image(img))
    plt.axis('off')
    plt.show()

  # 1. Normalize the data: subtract the mean RGB (zero mean).
  # The mean is computed on the training set only and reused for the test set.
  mean_image = X_train.mean(dim=(0, 2, 3), keepdim=True)
  X_train -= mean_image
  X_test -= mean_image

  # 2. Reshape the image data into rows
  X_train = X_train.reshape(X_train.shape[0], -1)
  X_test = X_test.reshape(X_test.shape[0], -1)

  # 3. Bias trick: append a column of ones as the last feature.
  # The ones must share the data's dtype (not only its device) so that the
  # concatenated result keeps the requested `dtype`.
  if bias_trick:
    ones_train = torch.ones(
        X_train.shape[0], 1, dtype=X_train.dtype, device=X_train.device)
    X_train = torch.cat([X_train, ones_train], dim=1)
    ones_test = torch.ones(
        X_test.shape[0], 1, dtype=X_test.dtype, device=X_test.device)
    X_test = torch.cat([X_test, ones_test], dim=1)

  # 4. Take the validation set from the training set.
  # Note: It should not be taken from the test set.
  # For a random permutation, you can use torch.randperm or torch.randint,
  # but for this homework we use slicing instead: the first `num_training`
  # samples are kept for training and the remaining ones form the validation
  # set.
  num_training = int(X_train.shape[0] * (1.0 - validation_ratio))

  # return the dataset
  data_dict = {}
  data_dict['X_val'] = X_train[num_training:]
  data_dict['y_val'] = y_train[num_training:]
  data_dict['X_train'] = X_train[:num_training]
  data_dict['y_train'] = y_train[:num_training]

  data_dict['X_test'] = X_test
  data_dict['y_test'] = y_test
  return data_dict

